# Kernel: hgemm_tn_128x128

# Copyright 2014 Nervana Systems Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#    http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Constant mapping: symbolic names for the immediates and constant-bank
# (c[0x0][...]) slots referenced by the code below.
#
#   addr_zero       - shared-memory address (4x<128*8*4>) that the setup code
#                     zeroes with STS.128 and later reads back to obtain zeros.
#   gridDimA/B      - grid dimension slots; gridDimA is used when building the
#                     per-thread seed index (gridDimB is not referenced in this file).
#   param_X[0]/[1]  - low / high 32-bit words of a 64-bit pointer parameter
#                     (consumed in pairs by LEA / LEA.HI.X).
#   param_lda8/ldb8 - added directly to the trackA/trackB byte addresses once
#                     per 8-deep k step; shifted right by 4 to derive lda/ldb.
#   param_ldaz/ldbz - multiplied by blkZ when forming the A/B element offsets.
#   param_flags     - bit 0 selects whether a per-thread seed is loaded from param_Rand.
#
# Parameters not referenced in this file (param_C, param_ldc, param_alpha,
# param_beta, param_ldcz, param_loops) are presumably used by the included
# common code - TODO confirm.
<CONSTANT_MAPPING>
    addr_zero  : 4x<128*8*4>

    gridDimA : c[0x0][0x14]
    gridDimB : c[0x0][0x18]

    param_Rand[0]   : c[0x0][0x140]
    param_Rand[1]   : c[0x0][0x144]
    param_A[0]      : c[0x0][0x148]
    param_A[1]      : c[0x0][0x14c]
    param_B[0]      : c[0x0][0x150]
    param_B[1]      : c[0x0][0x154]
    param_C[0]      : c[0x0][0x158]
    param_C[1]      : c[0x0][0x15c]
    param_lda8      : c[0x0][0x160]
    param_ldb8      : c[0x0][0x164]
    param_ldc       : c[0x0][0x168]
    param_m         : c[0x0][0x16c]
    param_n         : c[0x0][0x170]
    param_k         : c[0x0][0x174]
    param_alpha     : c[0x0][0x178]
    param_beta      : c[0x0][0x17c]
    param_flags     : c[0x0][0x180]
    param_ldaz      : c[0x0][0x184]
    param_ldbz      : c[0x0][0x188]
    param_ldcz      : c[0x0][0x18c]
    param_loops     : c[0x0][0x190]
</CONSTANT_MAPPING>

# Register mapping: names for physical registers R0-R127. The same physical
# registers are deliberately given several names because different phases of
# the kernel use them for different purposes.
# In maxas notation "a-b : names" assigns registers to names in the listed
# order, while "a-b ~ names" lets the assembler pick registers from the range.
<REGISTER_MAPPING>

    // Scratch values used only while computing addresses in the setup and
    // REMAINDER code (x<1-3>/y<1-3> are the scalar-path column indices).
    64-95   ~ lda, ldb, ldaz, ldbz, tid1, tid7, tid31, tid128, tid15, tbid, tidX, blk, dimA, flag, x<1-3>, y<1-3>

    // R0-R63 under the name czero<NN>: filled with zeros from [addr_zero] during setup.
    0-63    : czero<00-63>

    // The same R0-R63 under cx<X>y<Y> names, in a permuted register order.
    // Presumably the accumulator tile used by the included common loop - TODO confirm.
     3, 2,11,10,19,18,27,26 : cx<0-7>y0
     7, 6,15,14,23,22,31,30 : cx<0-7>y1
     1, 0, 9, 8,17,16,25,24 : cx<0-7>y2
     5, 4,13,12,21,20,29,28 : cx<0-7>y3
    35,34,43,42,51,50,59,58 : cx<0-7>y4
    39,38,47,46,55,54,63,62 : cx<0-7>y5
    33,32,41,40,49,48,57,56 : cx<0-7>y6
    37,36,45,44,53,52,61,60 : cx<0-7>y7

    // Two sets (j0 / j1) of A and B operand registers; not referenced in this
    // file, presumably consumed by the included common loop - TODO confirm.
    64-79   : j0Ay<0-7>, j0Bx<0-7>
    80-95   : j1Ay<0-7>, j1Bx<0-7>

    // Staging registers for global loads of A and B (fp16 on arrival,
    // converted in place to fp32 before the STS.128 to shared memory).
    96-103  : loadA<0-3>, loadB<0-3>

    // 64-bit global addresses (low/high word pairs) walking through A and B.
    104-107 : trackA<0-1>, trackB<0-1>

    // Values computed in setup that stay live across the loop (shared-memory
    // read/write offsets, remaining k, ids, seed).
    108-125 ~ writeS, k, txa, txb, tidY, ta, tb, loop
    119-125 ~ readAs, readBs, tid, blkA, blkB, blkZ, seed
    // 64-bit address of this thread's entry in the param_Rand buffer.
    126-127 : Rand<0-1>

    // The names below are not referenced in this file; presumably they belong
    // to the output stage in the included common code - TODO confirm.
    64-75   ~ ldc, ldcz, ci, xmad_c, clk_shf1, clk_shf2, tid_31, tid_96, tid_128

    64-75   : c<0-7>, d3, d2, d1, d0
    76-85   : C00y<0-1>, C04y<0-1>, C08y<0-1>, C12y<0-1>
    86-107  ~ lfsr<0-2>, ldc1, ldc4, ldc60, writeCs, readCs, cx<00|64>, cy<00|04|08|12>, alpha, beta

    108-125 ~ exp<0-3>, rand<0-3>, lfsr<0-2>_1, lfsr<0-2>_2, flags

</REGISTER_MAPPING>

// Control-code prefix on every instruction (maxas notation):
//   wait mask : read barrier : write barrier : yield : stall count
//
// Read the thread id and the three block ids. Each S2R sets its own write
// barrier (1-4); the first consumers below wait on the matching mask bit
// (01 = tid, 02 = blkA, 04 = blkB, 08 = blkZ).
--:-:1:-:1      S2R tid,  SR_TID.X;
--:-:2:-:1      S2R blkA, SR_CTAID.Y;
--:-:3:-:1      S2R blkB, SR_CTAID.Z;
--:-:4:-:1      S2R blkZ, SR_CTAID.X;

// One-time setup: zero registers 0-63, pick the per-thread seed, and compute
// the global-memory pointers (trackA/trackB), bounds predicates (P5/P6) and
// shared-memory offsets (writeS/readAs/readBs) used by the rest of the kernel.
<SCHEDULE_BLOCK>
--:-:-:-:1      MOV k,    param_k;
--:-:-:-:1      MOV loop, RZ;

// Write zeros to shared memory at addr_zero, then load them back into
// czero00..czero60 (16 x 128-bit loads) to clear registers 0-63.
--:-:-:-:1      STS.128 [addr_zero], RZ;
<CODE>
        join('', map sprintf("--:-:-:-:1      LDS.U.128 czero%02d, [addr_zero];\n", $_ * 4), 0..15);
</CODE>

// Grab a seed for this thread
// seed index = (blkB*gridDimA*256 + blkA*256 + tid) & (2048*32 - 1)
// Rand       = param_Rand + index*4
// P4 = (param_flags & 1); only when P4 is set is seed replaced by the word at [Rand].
// NOTE(review): the LDG below sets no write barrier; its consumer is not in
// this file - confirm the included code does not depend on one.
--:-:-:-:1      MOV flag, param_flags;
--:-:-:-:1      LOP.AND.NZ P4, RZ, flag, 0x1;
--:-:-:-:1      MOV dimA, gridDimA;
03:-:-:-:1      ISCADD tbid, blkA, tid, 8;
04:-:-:-:1      XMAD.U16.U16 dimA, blkB, dimA, RZ;
--:-:-:-:1      ISCADD tbid, dimA, tbid, 8;
--:-:-:-:1      LOP.AND seed, tbid, 1x<2048*32 - 1>;
--:-:-:-:1      LEA      Rand0.CC, seed, param_Rand[0],     0x2;
--:-:-:-:1      LEA.HI.X Rand1,    seed, param_Rand[1], RZ, 0x2;
--:-:-:-:1  @P4 LDG.E.CS seed, [Rand];

// Individual bits of tid used later for the shared-memory read offsets.
--:-:-:-:1      LOP.AND tid1,   tid,  1;
--:-:-:-:1      LOP.AND tid128, tid,  128;

// tidX = (tid & 31) << 2
// tidY = (tid >> 5) & 7
01:-:-:-:1      LOP.AND tid31,  tid,  31;
--:-:-:-:1      SHL     tidX,   tid31, 2;
--:-:-:-:1      BFE.U32 tidY,   tid,  0x305; // 3 bits at position 5

// lda = param_lda8 >> 4, ldb = param_ldb8 >> 4 (strides used in element offsets below)
--:-:-:-:1      MOV lda,  param_lda8;
--:-:-:-:1      MOV ldb,  param_ldb8;
--:-:-:-:1      SHR.U32 lda, lda, 4;
--:-:-:-:1      SHR.U32 ldb, ldb, 4;
--:-:-:-:1      MOV ldaz, param_ldaz;
--:-:-:-:1      MOV ldbz, param_ldbz;

// txa    = blkA*128 + tidX
// trackA = param_A + (txa + lda*tidY + ldaz*blkZ) * 2
02:-:-:-:1      ISCADD   txa, blkA, tidX, 7;
--:-:-:-:1      XMAD.LO2 ta,  lda,  tidY, txa;
08:-:-:-:1      XMAD.LO2 ta,  ldaz, blkZ, ta;
--:-:-:-:1      LEA      trackA0.CC, ta, param_A[0],     0x1;
--:-:-:-:1      LEA.HI.X trackA1,    ta, param_A[1], RZ, 0x1;

// P5 = txa < m (this thread's first A column is in range)
--:-:-:-:1      ISETP.LT.AND P5, PT, txa, param_m, PT;

// txb    = blkB*128 + tidX
// trackB = param_B + (txb + ldb*tidY + ldbz*blkZ) * 2
04:-:-:-:1      ISCADD   txb, blkB, tidX, 7;
--:-:-:-:1      XMAD.LO2 tb,  ldb,  tidY, txb;
08:-:-:-:1      XMAD.LO2 tb,  ldbz, blkZ, tb;
--:-:-:-:1      LEA      trackB0.CC, tb, param_B[0],     0x1;
--:-:-:-:1      LEA.HI.X trackB1,    tb, param_B[1], RZ, 0x1;

// P6 = txb < n (this thread's first B column is in range)
--:-:-:-:1      ISETP.LT.AND P6, PT, txb, param_n, PT;

// writeS = (128*tidY + tidX) * 4 + 4x<128*8*2>
// The extra 4x<128*8*2> term is the same constant that is XOR-ed into
// writeS/readAs/readBs after each store to flip between the two buffers.
--:-:-:-:1      ISCADD  writeS, tidY, tidX, 7;
--:-:-:-:1      ISCADD  writeS, writeS, 4x<128*8*2>, 2;


// readAs  = (((tid & 0x70) >> 3) | (tid & 1)) << 4
--:-:-:-:1      LOP.AND readAs, tid,    0x70;
--:-:-:-:1      SHR.U32 readAs, readAs, 3;
--:-:-:-:1      LOP.OR  readAs, readAs, tid1;
--:-:-:-:1      SHL     readAs, readAs, 4;

// readBs = (((tid128 >> 4) | ((tid >> 1) & 7)) << 4) + 4096
// 4096 = 4x<128*8>, the byte offset at which the B tile is stored (see the
// STS at writeS + 4x<8*128>).
--:-:-:-:1      BFE.U32 tid7,   tid,    0x301; // 3 bits at position 1
--:-:-:-:1      SHR.U32 readBs, tid128, 4;
--:-:-:-:1      LOP.OR  readBs, readBs, tid7;
--:-:-:-:1      ISCADD  readBs, readBs, 4x<128*8>, 4;
</SCHEDULE_BLOCK>

// REMAINDER: bounds-checked load of one 8-deep tile of A and B into
// loadA/loadB. Reached by fall-through after setup and again through the
// "@P1 BRA.U REMAINDER" emitted by the loop inserts at the end of this file.
// Elements that fail the bounds check are replaced by zero.
// $vec is a Perl package variable set outside this file; it selects the
// 64-bit vector-load variant over the element-by-element variant.
REMAINDER:

<SCHEDULE_BLOCK>

<CODE>
    our $vec;
    return $vec ? q{

// Vector path: one 64-bit load (four fp16 values) per matrix.
// doLoad = tidY < k && txa|txb < m|n
--:-:-:-:1      ISETP.LT.AND P2, PT, tidY, k, P5;
--:-:-:-:1      ISETP.LT.AND P3, PT, tidY, k, P6;

// In-range threads load from global memory (write barriers 2 / 3).
--:-:2:-:1  @P2 LDG.E.CI.64 loadA, [trackA];
--:-:3:-:1  @P3 LDG.E.CI.64 loadB, [trackB];

// Out-of-range threads read the zeros stored at addr_zero instead
// (write barriers 5 / 6).
--:-:5:-:1 @!P2 LDS.U.64 loadA, [addr_zero];
--:-:6:-:1 @!P3 LDS.U.64 loadB, [addr_zero];

    // Vec 4 and scalar loads
    } : q{

// Scalar path: four 16-bit loads per matrix, each with its own column check.
// doLoadA = tidY < k && txa < m
// doLoadB = tidY < k && txb < n
// P0..P3 guard columns txa+0..txa+3; P1-P3 are chained on P0.
--:-:-:-:1      IADD x1, txa, 1;
--:-:-:-:1      IADD x2, txa, 2;
--:-:-:-:1      IADD x3, txa, 3;
--:-:-:-:1      ISETP.LT.AND P0, PT, tidY, k, P5;
--:-:-:-:1      ISETP.LT.AND P1, PT, x1, param_m, P0;
--:-:-:-:1      ISETP.LT.AND P2, PT, x2, param_m, P0;
--:-:-:-:1      ISETP.LT.AND P3, PT, x3, param_m, P0;

// All four A loads share write barrier 2.
--:-:2:-:1  @P0 LDG.E.CI.S16 loadA0, [trackA + 2x<00 + 0>];
--:-:2:-:1  @P1 LDG.E.CI.S16 loadA1, [trackA + 2x<00 + 1>];
--:-:2:-:1  @P2 LDG.E.CI.S16 loadA2, [trackA + 2x<00 + 2>];
--:-:2:-:1  @P3 LDG.E.CI.S16 loadA3, [trackA + 2x<00 + 3>];

// Zero the elements that were not loaded.
--:-:-:-:1 @!P0 MOV loadA0, RZ;
--:-:-:-:1 @!P1 MOV loadA1, RZ;
--:-:-:-:1 @!P2 MOV loadA2, RZ;
--:-:-:-:1 @!P3 MOV loadA3, RZ;

// Same again for B, reusing P0..P3 and checking columns against n.
--:-:-:-:1      IADD y1, txb, 1;
--:-:-:-:1      IADD y2, txb, 2;
--:-:-:-:1      IADD y3, txb, 3;
--:-:-:-:1      ISETP.LT.AND P0, PT, tidY, k, P6;
--:-:-:-:1      ISETP.LT.AND P1, PT, y1, param_n, P0;
--:-:-:-:1      ISETP.LT.AND P2, PT, y2, param_n, P0;
--:-:-:-:1      ISETP.LT.AND P3, PT, y3, param_n, P0;

// All four B loads share write barrier 3.
--:-:3:-:1  @P0 LDG.E.CI.S16 loadB0, [trackB + 2x<00 + 0>];
--:-:3:-:1  @P1 LDG.E.CI.S16 loadB1, [trackB + 2x<00 + 1>];
--:-:3:-:1  @P2 LDG.E.CI.S16 loadB2, [trackB + 2x<00 + 2>];
--:-:3:-:1  @P3 LDG.E.CI.S16 loadB3, [trackB + 2x<00 + 3>];

--:-:-:-:1 @!P0 MOV loadB0, RZ;
--:-:-:-:1 @!P1 MOV loadB1, RZ;
--:-:-:-:1 @!P2 MOV loadB2, RZ;
--:-:-:-:1 @!P3 MOV loadB3, RZ;

    };
</CODE>

</SCHEDULE_BLOCK>

// Convert the tile just loaded by REMAINDER from fp16 to fp32 in place, store
// it to shared memory (A at writeS + 4x<0*128>, B at writeS + 4x<8*128>), and
// advance trackA/trackB by param_lda8/param_ldb8 (64-bit add: .CC sets the
// carry, IADD.X consumes it). Also computes P1, the predicate tested by the
// "@P1 BRA.U REMAINDER" in the loop inserts.
// This section is outside any SCHEDULE_BLOCK, so the order and stall counts
// written here are presumably emitted unchanged - keep them intact.
<CODE>
    our $vec;
    return $vec ? q{
// bDoRemainder = k & 7 && k > 8
--:-:-:-:0      LOP.AND.NZ P1, RZ, k, 7;

// Wait mask 12 = barriers 2|5 (global load / zero-fill of loadA).
// The packed halves are unpacked from the top down (A3, A2, A1, A0) so that
// loadA1 and loadA0 are read before they are overwritten.
12:-:-:-:4      F2F.F32.F16 loadA3, loadA1.H1;
--:-:-:-:0      IADD   trackA0.CC, trackA0, param_lda8;
--:-:-:-:4      F2F.F32.F16 loadA2, loadA1.H0;
--:-:-:-:4      F2F.F32.F16 loadA1, loadA0.H1;
--:-:-:-:0      IADD.X trackA1, trackA1, RZ;
--:-:2:-:2      F2F.F32.F16 loadA0, loadA0.H0;

// Wait for the last conversion (barrier 2) before storing all four values.
02:-:-:-:1      STS.128 [writeS + 4x<0*128>], loadA;

// Wait mask 24 = barriers 3|6 (global load / zero-fill of loadB).
24:-:-:-:4      F2F.F32.F16 loadB3, loadB1.H1;
--:-:-:-:0      IADD   trackB0.CC, trackB0, param_ldb8;
--:-:-:-:4      F2F.F32.F16 loadB2, loadB1.H0;
--:-:-:-:4      F2F.F32.F16 loadB1, loadB0.H1;
--:-:-:-:0      IADD.X trackB1, trackB1, RZ;
--:-:3:-:2      F2F.F32.F16 loadB0, loadB0.H0;

// P1 = (k & 7) != 0 && k > 8
--:-:-:-:0      ISETP.GT.AND P1, PT, k, 8, P1;

04:-:-:-:1      STS.128 [writeS + 4x<8*128>], loadB;

    // scalar loads
    } : q{
// Scalar path: P1 = k > 8 only (no k & 7 test).
--:-:-:-:0      ISETP.GT.AND P1, PT, k, 8, PT;

// Each register already holds one fp16 value; convert in place once the
// A loads (barrier 2) have landed.
02:-:-:-:4      F2F.F32.F16 loadA0, loadA0;
--:-:-:-:0      IADD   trackA0.CC, trackA0, param_lda8;
--:-:-:-:4      F2F.F32.F16 loadA1, loadA1;
--:-:-:-:4      F2F.F32.F16 loadA2, loadA2;
--:-:2:-:2      F2F.F32.F16 loadA3, loadA3;

--:-:-:-:0      IADD.X trackA1, trackA1, RZ;

02:-:-:-:1      STS.128 [writeS + 4x<0*128>], loadA0;

// Same for B, waiting on barrier 3.
04:-:-:-:4      F2F.F32.F16 loadB0, loadB0;
--:-:-:-:0      IADD   trackB0.CC, trackB0, param_ldb8;
--:-:-:-:4      F2F.F32.F16 loadB1, loadB1;
--:-:-:-:4      F2F.F32.F16 loadB2, loadB2;
--:-:3:-:2      F2F.F32.F16 loadB3, loadB3;

--:-:-:-:0      IADD.X trackB1, trackB1, RZ;

04:-:-:-:1      STS.128 [writeS + 4x<8*128>], loadB0;

    };
</CODE>

// Flip readAs, readBs and writeS between the two shared-memory buffers by
// toggling the 4x<128*8*2> bit. writeS was initialised with that bit set and
// readAs/readBs without it, so after this flip the read offsets point at the
// buffer that was just written and writeS points at the other one.
// BAR.SYNC 0 synchronises the thread block between the STS stores above and
// whatever follows; the reads themselves are presumably in the included
// common loop - TODO confirm.
--:-:-:-:1      LOP.XOR readAs, readAs, 4x<128*8*2>;
--:-:-:-:0      LOP.XOR readBs, readBs, 4x<128*8*2>;
01:-:-:-:5      BAR.SYNC 0;
--:-:-:-:0      LOP.XOR writeS, writeS, 4x<128*8*2>;


// Perl section that defines the instructions to be woven into the main loop.
// It only fills the package variables @top and %insert and returns nothing;
// they are presumably consumed by the common file included below - TODO confirm.
<CODE>
    our $vec;
    # Minimum remaining k for which the loop prefetches another tile and iterates again.
    # The vector path uses 16 and the scalar path 24. NOTE(review): the scalar
    # loop loads carry no per-column bounds check, so the larger threshold
    # presumably keeps them off the final tile - confirm.
    my $k_end = $vec ? 16 : 24;
    # P2 = k >= k_end && P5 : prefetch the next A tile.
    our @top = ("--:-:-:-:1      ISETP.GE.AND P2, PT, k, $k_end, P5;\n");

    # Keys look like j<N>c<M>; presumably N is the unrolled k step and M the
    # instruction slot at which the text is inserted - verify against the include.
    our %insert =
    (
        # P3 = k >= k_end && P6 : prefetch the next B tile.
        # P0 = k >= k_end       : store the prefetched tile and loop again.
        j0c1  => "--:-:-:-:1      ISETP.GE.AND P3, PT, k, $k_end, P6;\n",
        j0c3  => "--:-:-:-:1      ISETP.GE.AND P0, PT, k, $k_end, PT;\n",

        ($vec ? 
            (
        # Vector path: one 64-bit load per matrix (write barriers 2 / 3).
        j0c10 => "--:-:2:-:1  \@P2 LDG.E.CI.64 loadA0, [trackA];\n",
        j0c13 => "--:-:3:-:1  \@P3 LDG.E.CI.64 loadB0, [trackB];\n",

        # Unpack fp16 -> fp32 from the top down so loadA1/loadA0 are read
        # before being overwritten. NOTE(review): the low-half sources here
        # carry no .H0 suffix, whereas the first-tile code writes .H0
        # explicitly; presumably equivalent - confirm.
        j5c1  => "02:-:-:-:1  \@P2 F2F.F32.F16 loadA3, loadA1.H1;\n",
        j5c5  => "--:-:-:-:1  \@P2 F2F.F32.F16 loadA2, loadA1;\n",
        j5c9  => "--:-:-:-:1  \@P2 F2F.F32.F16 loadA1, loadA0.H1;\n",
        j5c13 => "--:-:2:-:1  \@P2 F2F.F32.F16 loadA0, loadA0;\n",

        j6c1  => "04:-:-:-:1  \@P3 F2F.F32.F16 loadB3, loadB1.H1;\n",
        j6c5  => "--:-:-:-:1  \@P3 F2F.F32.F16 loadB2, loadB1;\n",
        j6c9  => "--:-:-:-:1  \@P3 F2F.F32.F16 loadB1, loadB0.H1;\n",
        j6c13 => "--:-:3:-:1  \@P3 F2F.F32.F16 loadB0, loadB0;\n",
            ) : 
            (
        # Scalar path: four 16-bit loads per matrix, guarded only by P2 / P3
        # (no per-column check, unlike the REMAINDER code).
        j0c10 => "--:-:2:-:1  \@P2 LDG.E.CI.S16 loadA0, [trackA + 2x<0>];\n",
        j0c12 => "--:-:2:-:1  \@P2 LDG.E.CI.S16 loadA1, [trackA + 2x<1>];\n",
        j0c14 => "--:-:2:-:1  \@P2 LDG.E.CI.S16 loadA2, [trackA + 2x<2>];\n",
        j0c16 => "--:-:2:-:1  \@P2 LDG.E.CI.S16 loadA3, [trackA + 2x<3>];\n",

        j0c29 => "--:-:3:-:1  \@P3 LDG.E.CI.S16 loadB0, [trackB + 2x<0>];\n",
        j0c31 => "--:-:3:-:1  \@P3 LDG.E.CI.S16 loadB1, [trackB + 2x<1>];\n",
        j0c33 => "--:-:3:-:1  \@P3 LDG.E.CI.S16 loadB2, [trackB + 2x<2>];\n",
        j0c35 => "--:-:3:-:1  \@P3 LDG.E.CI.S16 loadB3, [trackB + 2x<3>];\n",

        # In-place fp16 -> fp32 conversion of each element.
        j5c1  => "02:-:-:-:1  \@P2 F2F.F32.F16 loadA0, loadA0;\n",
        j5c5  => "--:-:-:-:1  \@P2 F2F.F32.F16 loadA1, loadA1;\n",
        j5c9  => "--:-:-:-:1  \@P2 F2F.F32.F16 loadA2, loadA2;\n",
        j5c13 => "--:-:2:-:1  \@P2 F2F.F32.F16 loadA3, loadA3;\n",

        j6c1  => "04:-:-:-:1  \@P3 F2F.F32.F16 loadB0, loadB0;\n",
        j6c5  => "--:-:-:-:1  \@P3 F2F.F32.F16 loadB1, loadB1;\n",
        j6c9  => "--:-:-:-:1  \@P3 F2F.F32.F16 loadB2, loadB2;\n",
        j6c13 => "--:-:3:-:1  \@P3 F2F.F32.F16 loadB3, loadB3;\n",
            )
        ),

        # Store the converted A tile (after barrier 2) and advance trackA by
        # param_lda8 with a 64-bit add (.CC carry consumed by IADD.X).
        j5c31 => "02:-:-:-:1  \@P0 STS.128 [writeS + 4x<0*128>], loadA;\n",

        j5c46 => "--:-:-:-:1  \@P2 IADD   trackA0.CC, trackA0, param_lda8;\n",
        j5c54 => "--:-:-:-:1  \@P2 IADD.X trackA1,    trackA1, RZ;\n",

        # Same for B (barrier 3, param_ldb8).
        j6c31 => "04:-:-:-:1  \@P0 STS.128 [writeS + 4x<8*128>], loadB;\n",

        j6c46 => "--:-:-:-:1  \@P3 IADD   trackB0.CC, trackB0, param_ldb8;\n",
        j6c54 => "--:-:-:-:1  \@P3 IADD.X trackB1,    trackB1, RZ;\n",

        # When iterating again: synchronise the block and flip all three
        # shared-memory offsets to the other buffer. k is decremented by 8
        # unconditionally.
        j6c63 => "--:-:-:-:5  \@P0 BAR.SYNC 0;\n" .
                 "--:-:-:-:1  \@P0 LOP.XOR readAs, readAs, 4x<128*8*2>;\n" .
                 "--:-:-:-:1  \@P0 LOP.XOR readBs, readBs, 4x<128*8*2>;\n" .
                 "--:-:-:-:1  \@P0 LOP.XOR writeS, writeS, 4x<128*8*2>;\n" .
                 "--:-:-:-:1      IADD32I k, k, -8;\n",

        # P0: run the loop again; otherwise P1: go back to REMAINDER for a
        # bounds-checked load. The LOOP label is not defined in this file.
        j7c63 => "--:-:-:Y:5  \@P0 BRA.U LOOP;\n" .
                 "--:-:-:Y:5  \@P1 BRA.U REMAINDER;\n",
    );
    return;
</CODE>

<INCLUDE file="nervanagpu/kernels/sass/hgemm_common_128x128.sass"/>
